{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2023-05-10T03:17:15.673316400Z",
     "start_time": "2023-05-10T03:17:15.535320Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "# Toy sizes: T is the length of the time series, H the number of hidden-state elements.\n",
    "T, H = 5, 4\n",
    "hs = np.random.randn(T, H)\n",
    "# One weight per row of hs (T values); these five weights add up to 1.\n",
    "a = np.array([0.8, 0.1, 0.03, 0.05, 0.02])"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T03:17:55.534942400Z",
     "start_time": "2023-05-10T03:17:55.525940200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5, 4)\n"
     ]
    }
   ],
   "source": [
    "# Turn a into a column and copy it 4 times along axis 1 so that ar has the\n",
    "# same (5, 4) shape as hs; 5 and 4 are the values of T and H set above.\n",
    "ar = a.reshape(5, 1).repeat(4, axis=1)\n",
    "print(ar.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T03:18:37.759725300Z",
     "start_time": "2023-05-10T03:18:37.685730Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "# Element-wise product: row i of hs is scaled by the weight a[i].\n",
    "t = hs * ar"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T03:18:51.571968100Z",
     "start_time": "2023-05-10T03:18:51.562966Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "data": {
      "text/plain": "(5, 4)"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the shape of the element-wise product.\n",
    "t.shape"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T03:18:56.731604100Z",
     "start_time": "2023-05-10T03:18:56.703575200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4,)\n"
     ]
    }
   ],
   "source": [
    "# Summing over axis 0 (the T axis) collapses the weighted rows into one vector of length H.\n",
    "c = np.sum(t, axis=0)\n",
    "print(c.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T03:19:13.528535800Z",
     "start_time": "2023-05-10T03:19:13.519534600Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "设时序数据的长度 T=5，隐藏状态向量的元素个数 H=4，这里给出了加权和的计算过程。我们先关注代码 ar = a.reshape(5, 1).repeat(4, axis=1)。\n",
    "\"\"\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "# Same computation with an extra leading axis of size N (presumably the mini-batch; confirm).\n",
    "N, T, H = 10, 5, 4\n",
    "hs = np.random.randn(N, T, H)\n",
    "a = np.random.randn(N, T)\n",
    "# Add a trailing axis to a and copy it H times so that ar has the shape of hs.\n",
    "ar = a.reshape(N, T, 1).repeat(H, axis=2)\n",
    "# ar = a.reshape(N, T, 1)  # alternative: rely on broadcasting instead of repeat"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T03:45:32.096045300Z",
     "start_time": "2023-05-10T03:45:32.090045Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 5, 4)\n"
     ]
    }
   ],
   "source": [
    "# Element-wise product: hs[n, t] is scaled by the weight a[n, t].\n",
    "t = hs * ar\n",
    "print(t.shape)\n",
    "# Expected output: (10, 5, 4)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T03:45:39.786891800Z",
     "start_time": "2023-05-10T03:45:39.776892400Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 4)\n"
     ]
    }
   ],
   "source": [
    "# Sum over axis 1 (the T axis), leaving one H-element vector for each of the N entries.\n",
    "c = np.sum(t, axis=1)\n",
    "print(c.shape)\n",
    "# Expected output: (10, 4)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T03:45:53.657794200Z",
     "start_time": "2023-05-10T03:45:53.633794500Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "class WeightSum:\n",
    "    \"\"\"Weighted sum of hidden states over the time axis.\n",
    "\n",
    "    forward(hs, a) takes hs of shape (N, T, H) and weights a holding N*T\n",
    "    values, and returns c of shape (N, H) with\n",
    "    c[n] = sum over t of a[n, t] * hs[n, t].\n",
    "    backward(dc) returns the gradients with respect to hs and a.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # This layer adds nothing to params/grads; both stay empty lists.\n",
    "        self.params, self.grads = [], []\n",
    "        # Set by forward() to (hs, ar) and read back by backward().\n",
    "        self.cache = None\n",
    "\n",
    "    def forward(self, hs, a):\n",
    "        \"\"\"Return the weighted sum of hs over axis 1.\n",
    "\n",
    "        Args:\n",
    "            hs: array of shape (N, T, H).\n",
    "            a: array with N*T elements, one weight per (n, t) position.\n",
    "\n",
    "        Returns:\n",
    "            Array of shape (N, H).\n",
    "        \"\"\"\n",
    "        N, T, H = hs.shape\n",
    "\n",
    "        # Keep the weights as (N, T, 1) and let NumPy broadcast them over H.\n",
    "        # This yields the same product as repeat(H, axis=2) without building\n",
    "        # an (N, T, H) copy of the weights.\n",
    "        ar = a.reshape(N, T, 1)\n",
    "        t = hs * ar\n",
    "        c = np.sum(t, axis=1)\n",
    "\n",
    "        # hs is cached by reference and ar may be a view of a, so callers\n",
    "        # should not modify either in place before backward() runs.\n",
    "        self.cache = (hs, ar)\n",
    "        return c\n",
    "\n",
    "    def backward(self, dc):\n",
    "        \"\"\"Back-propagate dc of shape (N, H); must be called after forward().\n",
    "\n",
    "        Returns:\n",
    "            (dhs, da): gradients of shape (N, T, H) and (N, T).\n",
    "        \"\"\"\n",
    "        hs, ar = self.cache\n",
    "        N, T, H = hs.shape\n",
    "        # Backward of the sum over T: the same dc reaches every time step,\n",
    "        # which broadcasting an (N, 1, H) array expresses directly.\n",
    "        dt = dc.reshape(N, 1, H)\n",
    "        dar = dt * hs\n",
    "        dhs = dt * ar\n",
    "        # Backward of spreading each weight over H: add the H contributions up.\n",
    "        da = np.sum(dar, axis=2)\n",
    "        return dhs, da"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T04:48:15.890192400Z",
     "start_time": "2023-05-10T04:48:15.876193800Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 8.1.4 解码器的改进②"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append('..')\n",
    "from common.layers import Softmax\n",
    "import numpy as np"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T04:48:56.291474900Z",
     "start_time": "2023-05-10T04:48:56.270474800Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "N, T, H = 10, 5, 4\n",
    "hs = np.random.randn(N, T, H)\n",
    "h = np.random.randn(N, H)\n",
    "# Insert a length-1 axis into h and copy it T times so that hr has the (N, T, H) shape of hs.\n",
    "hr = h.reshape(N, 1, H).repeat(T, axis=1)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T04:49:17.599179Z",
     "start_time": "2023-05-10T04:49:17.586181Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 5, 4)\n"
     ]
    }
   ],
   "source": [
    "# Element-wise product of every hs[n, t] with the matching h[n].\n",
    "t = hs * hr\n",
    "print(t.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T04:49:24.916379300Z",
     "start_time": "2023-05-10T04:49:24.907380Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 5)\n"
     ]
    }
   ],
   "source": [
    "# Summing the products over axis 2 (the H axis) gives the dot product of h[n] with each hs[n, t].\n",
    "s = np.sum(t, axis=2)\n",
    "print(s.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T04:49:29.674692500Z",
     "start_time": "2023-05-10T04:49:29.669688900Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 5)\n"
     ]
    }
   ],
   "source": [
    "# Pass the scores through the Softmax layer imported from common.layers;\n",
    "# the printed shape is the same as that of s.\n",
    "softmax = Softmax()\n",
    "a = softmax.forward(s)\n",
    "print(a.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T04:49:37.098099800Z",
     "start_time": "2023-05-10T04:49:37.088099400Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [],
   "source": [
    "from common.np import *  # import numpy as np\n",
    "from common.layers import Softmax"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T04:56:36.124008300Z",
     "start_time": "2023-05-10T04:56:36.109008600Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [],
   "source": [
    "class AttentionWeight:\n",
    "    \"\"\"Score every row of hs against h and pass the scores through Softmax.\n",
    "\n",
    "    forward(hs, h) takes hs of shape (N, T, H) and h holding N*H values,\n",
    "    computes s[n, t] = sum over k of hs[n, t, k] * h[n, k], and returns\n",
    "    self.softmax.forward(s). backward(da) returns the gradients with\n",
    "    respect to hs and h.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # This layer adds nothing to params/grads; both stay empty lists.\n",
    "        self.params, self.grads = [], []\n",
    "        self.softmax = Softmax()\n",
    "        # Set by forward() to (hs, hr) and read back by backward().\n",
    "        self.cache = None\n",
    "\n",
    "    def forward(self, hs, h):\n",
    "        \"\"\"Return the softmax output for the (N, T) matrix of dot-product scores.\n",
    "\n",
    "        Args:\n",
    "            hs: array of shape (N, T, H).\n",
    "            h: array with N*H elements, one H-vector per n.\n",
    "        \"\"\"\n",
    "        N, T, H = hs.shape\n",
    "        # Keep h as (N, 1, H) and let NumPy broadcast it over T. This yields\n",
    "        # the same product as repeat(T, axis=1) without building an\n",
    "        # (N, T, H) copy of h.\n",
    "        hr = h.reshape(N, 1, H)\n",
    "        t = hs * hr\n",
    "        s = np.sum(t, axis=2)\n",
    "        a = self.softmax.forward(s)\n",
    "        # hs is cached by reference and hr may be a view of h, so callers\n",
    "        # should not modify either in place before backward() runs.\n",
    "        self.cache = (hs, hr)\n",
    "        return a\n",
    "\n",
    "    def backward(self, da):\n",
    "        \"\"\"Back-propagate da; must be called after forward().\n",
    "\n",
    "        Returns:\n",
    "            (dhs, dh): gradients of shape (N, T, H) and (N, H).\n",
    "        \"\"\"\n",
    "        hs, hr = self.cache\n",
    "        N, T, H = hs.shape\n",
    "        ds = self.softmax.backward(da)\n",
    "        # Backward of the sum over H: the same ds[n, t] reaches every element,\n",
    "        # which broadcasting an (N, T, 1) array expresses directly.\n",
    "        dt = ds.reshape(N, T, 1)\n",
    "        dhs = dt * hr\n",
    "        dhr = dt * hs\n",
    "        # Backward of spreading h over T: add the T contributions up.\n",
    "        dh = np.sum(dhr, axis=1)\n",
    "        return dhs, dh"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T05:13:35.258535400Z",
     "start_time": "2023-05-10T05:13:35.252534Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Attention Weight 层关注编码器输出的各个单词向量 hs，并计算各个单词的权重 a；然后，Weight Sum 层计算 a 和 hs 的加权和，并输出上下文向量 c。我们将进行这一系列计算的层称为 Attention 层。\n",
    "\"\"\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [],
   "source": [
    "class Attention:\n",
    "    \"\"\"Chain an AttentionWeight layer and a WeightSum layer.\n",
    "\n",
    "    forward(hs, h) first asks the AttentionWeight layer for weights from\n",
    "    (hs, h), then feeds (hs, weights) to the WeightSum layer and returns its\n",
    "    output. The weights of the latest call are kept in attention_weight.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # This layer adds nothing to params/grads; both stay empty lists.\n",
    "        self.params = []\n",
    "        self.grads = []\n",
    "        self.attention_weight_layer = AttentionWeight()\n",
    "        self.weight_sum_layer = WeightSum()\n",
    "        # Filled in by forward() with the most recent weights.\n",
    "        self.attention_weight = None\n",
    "\n",
    "    def forward(self, hs, h):\n",
    "        \"\"\"Return the WeightSum output for hs under the weights derived from h.\"\"\"\n",
    "        weights = self.attention_weight_layer.forward(hs, h)\n",
    "        context = self.weight_sum_layer.forward(hs, weights)\n",
    "        self.attention_weight = weights\n",
    "        return context\n",
    "\n",
    "    def backward(self, dout):\n",
    "        \"\"\"Back-propagate dout through both sub-layers; returns (dhs, dh).\"\"\"\n",
    "        dhs_via_sum, dweights = self.weight_sum_layer.backward(dout)\n",
    "        dhs_via_weight, dh = self.attention_weight_layer.backward(dweights)\n",
    "        # hs feeds both sub-layers, so its two gradients are added together.\n",
    "        return dhs_via_sum + dhs_via_weight, dh"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T05:42:55.960071Z",
     "start_time": "2023-05-10T05:42:55.947069100Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [],
   "source": [
    "class TimeAttention:\n",
    "    \"\"\"Run one Attention layer per time step of the decoder states.\n",
    "\n",
    "    The layer is just a collection of Attention layers: forward() creates a\n",
    "    fresh Attention for each index along axis 1 of hs_dec and keeps the\n",
    "    layers and their attention_weight values in two lists.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # This layer adds nothing to params/grads; both stay empty lists.\n",
    "        self.params = []\n",
    "        self.grads = []\n",
    "        # Both are rebuilt from scratch on every forward() call.\n",
    "        self.layers = None\n",
    "        self.attention_weights = None\n",
    "\n",
    "    def forward(self, hs_enc, hs_dec):\n",
    "        \"\"\"Apply Attention to every hs_dec[:, step, :] against hs_enc.\n",
    "\n",
    "        Returns an array with the same shape as hs_dec.\n",
    "        \"\"\"\n",
    "        _, n_steps, _ = hs_dec.shape\n",
    "        out = np.empty_like(hs_dec)\n",
    "        self.layers = []\n",
    "        self.attention_weights = []\n",
    "        for step in range(n_steps):\n",
    "            attn = Attention()\n",
    "            out[:, step, :] = attn.forward(hs_enc, hs_dec[:, step, :])\n",
    "            self.layers.append(attn)\n",
    "            self.attention_weights.append(attn.attention_weight)\n",
    "        return out\n",
    "\n",
    "    def backward(self, dout):\n",
    "        \"\"\"Back-propagate dout step by step; returns (dhs_enc, dhs_dec).\"\"\"\n",
    "        _, n_steps, _ = dout.shape\n",
    "        # Every step reads the same hs_enc, so its gradients are accumulated.\n",
    "        dhs_enc = 0\n",
    "        dhs_dec = np.empty_like(dout)\n",
    "        for step in range(n_steps):\n",
    "            dhs, dh = self.layers[step].backward(dout[:, step, :])\n",
    "            dhs_enc += dhs\n",
    "            dhs_dec[:, step, :] = dh\n",
    "        return dhs_enc, dhs_dec"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-10T05:45:40.898806200Z",
     "start_time": "2023-05-10T05:45:40.879807800Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "这里仅创建必要数量的 Attention 层（代码中为 T 个），各自进行正向传播和反向传播。另外，attention_weights 列表中保存了各个 Attention 层对各个单词的权重。\n",
    "\"\"\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
